import utils.config as config
import numpy as np
from pypower.makeYbus import makeYbus
import torch
from pypower.ext2int import ext2int1
from pypower.idx_brch import BR_STATUS
from pypower.idx_gen import GEN_STATUS,PMIN,PMAX



def get_relative_error(real, predict):
    """Signed percentage error of the summed prediction against the summed truth.

    Computes ``(sum(predict) - sum(real)) / sum(real) * 100``.

    Returns:
        The percentage error, or None when ``real`` and ``predict`` have
        different lengths.
    """
    if len(real) != len(predict):
        return None
    predicted_total = np.sum(predict)
    real_total = np.sum(real)
    return (predicted_total - real_total) / real_total * 100

def get_load_mismatch_rate(real, predict):
    """Element-wise absolute relative error, in percent.

    Computes ``abs((real - predict) / real) * 100`` element by element.

    Returns:
        An array of percentages, or None when ``real`` and ``predict`` have
        different lengths.
    """
    if len(real) != len(predict):
        return None
    relative_gap = (real - predict) / real
    return abs(relative_gap) * 100

## Other helper functions
def get_clamp(Pred, Predmin, Predmax):
    """Clamp each column of a 2-D tensor into its own [min, max] range.

    Each row of ``Pred`` is one sample. Column ``c`` is clipped from below by
    ``Predmin[c]`` and then from above by ``Predmax[c]``. ``Pred`` itself is
    not modified; a clamped clone is returned.
    """
    clipped = Pred.clone()
    n_cols = Pred.shape[1]
    for col in range(n_cols):
        lower, upper = Predmin[col], Predmax[col]
        raised = clipped[:, col].clamp(min=lower)
        clipped[:, col] = raised.clamp(max=upper)
    return clipped

def get_abs_error(real, predict):
    """Mean absolute error between two tensors.

    Returns:
        ``torch.mean(|real - predict|)``, or None when the inputs have
        different lengths.
    """
    if len(real) != len(predict):
        return None
    deviation = torch.abs(real - predict)
    return torch.mean(deviation)

def get_genload_N1_back(V, Pdtest, Qdtest, bus_Pg, bus_Qg):
    """Recover generator and load injections from bus voltages under N-1 contingencies.

    Scenario ``i`` uses the contingency code ``config.branch_test_id[i]``:

    * ``> 0``: branch ``code - 1`` (row of ``config.branch``) is switched off;
    * ``< 0``: generator ``-code - 1`` (row of ``config.gen``) is switched off
      and its PMIN/PMAX are zeroed;
    * ``0``: no outage.

    For each scenario the bus admittance matrix is rebuilt from that scenario's
    branch set, and the complex bus injection ``S = V * conj(Ybus @ V)`` is
    evaluated. Generator outages do not change Ybus. Only the generator's
    in-service flag is used below, to zero its Pg/Qg.

    Args:
        V: complex bus voltages, one scenario per row (row ``i`` is multiplied
            by Ybus). Presumably shape (n_scenarios, n_bus) -- TODO confirm.
        Pdtest, Qdtest: active / reactive load arrays indexed ``[scenario, bus]``.
        bus_Pg, bus_Qg: bus index of each generator. Entry ``j`` is assumed to
            describe row ``j`` of ``config.gen`` -- verify against the caller.

    Returns:
        Tuple ``(Pg, Qg, Pd, Qd)`` where, with ``P``/``Q`` the real/imaginary
        parts of ``S``:
        ``Pg[i, j] = P[i, bus_Pg[j]] + Pdtest[i, bus_Pg[j]]`` if generator ``j``
        has status 1 in scenario ``i``, else 0 (``Qg`` likewise).
        ``Pd = -P`` except at generator buses, where ``Pd = Pg - P``
        (``Qd`` likewise).
    """
    n_scenarios = V.shape[0]
    # np.complex128 is what the old 'complex_' alias meant; that alias is gone in NumPy 2.0.
    S = np.zeros(V.shape, dtype=np.complex128)
    branch_id = config.branch_test_id

    gen_status_list = []  # per-scenario in-service flags of every generator

    for i in range(n_scenarios):
        # Work on copies: pypower's ext2int1 writes renumbered bus indices back
        # into the arrays it is given, so passing config.bus directly would
        # silently mutate the shared configuration on every call.
        bus_test = config.bus.copy()
        branch_test = config.branch.copy()
        gen_test = config.gen.copy()

        code = branch_id[i]
        if code > 0:
            # Branch outage: take the branch out of service.
            branch_test[code - 1, BR_STATUS] = 0
        elif code < 0:
            # Generator outage: switch it off and force its output range to zero.
            gen_index = -code - 1  # 0-based generator row
            gen_test[gen_index, GEN_STATUS] = 0
            gen_test[gen_index, PMIN] = 0
            gen_test[gen_index, PMAX] = 0

        # Convert to internal (consecutive) bus numbering.
        _, bus_test, gen_test, branch_test = ext2int1(bus_test, gen_test, branch_test)

        # Only Ybus is needed; the branch admittance matrices are discarded.
        Ybus, _, _ = makeYbus(config.baseMVA, bus_test, branch_test)

        # Complex power injection S = V * conj(Ybus @ V).
        I = Ybus.dot(V[i]).conj()
        S[i] = np.multiply(V[i], I)

        gen_status_list.append(gen_test[:, GEN_STATUS])

    P = np.real(S)
    Q = np.imag(S)

    # One Pg/Qg column per generator; outaged generators keep 0.
    Pg = np.zeros((n_scenarios, len(bus_Pg)))
    Qg = np.zeros((n_scenarios, len(bus_Qg)))

    for i, gen_status in enumerate(gen_status_list):
        for j, status in enumerate(gen_status):
            if status == 1:  # generator in service
                Pg[i, j] = P[i, bus_Pg[j]] + Pdtest[i, bus_Pg[j]]
                Qg[i, j] = Q[i, bus_Qg[j]] + Qdtest[i, bus_Qg[j]]

    # Load is minus the net injection, except at generator buses where the
    # generator's share is removed first.
    Pd = -P
    Qd = -Q
    Pd[:, bus_Pg] = Pg - P[:, bus_Pg]
    Qd[:, bus_Qg] = Qg - Q[:, bus_Qg]

    return Pg, Qg, Pd, Qd


def get_genload_N1(V, Pdtest, Qdtest, bus_Pg, bus_Qg):
    """Recover generator and load injections from bus voltages under N-1 branch outages.

    Scenario ``i`` switches off branch ``config.branch_test_id[i] - 1`` (row of
    ``config.branch``) when that code is positive. A zero or negative code
    leaves the base topology unchanged; generator outages are not modelled in
    this variant. For each scenario the bus admittance matrix is rebuilt, and
    the complex bus injection ``S = V * conj(Ybus @ V)`` is evaluated.

    Args:
        V: complex bus voltages, one scenario per row (row ``i`` is multiplied
            by Ybus). Presumably shape (n_scenarios, n_bus) -- TODO confirm.
        Pdtest, Qdtest: active / reactive load arrays indexed ``[scenario, bus]``.
        bus_Pg, bus_Qg: bus index of each generator.

    Returns:
        Tuple ``(Pg, Qg, Pd, Qd)`` where, with ``P``/``Q`` the real/imaginary
        parts of ``S``:
        ``Pg = P[:, bus_Pg] + Pdtest[:, bus_Pg]`` and
        ``Qg = Q[:, bus_Qg] + Qdtest[:, bus_Qg]``.
        ``Pd = -P`` except at generator buses, where ``Pd = Pg - P``
        (i.e. Pdtest there, up to rounding). ``Qd`` is built likewise.
    """
    # np.complex128 is what the old 'complex_' alias meant; that alias is gone in NumPy 2.0.
    S = np.zeros(V.shape, dtype=np.complex128)
    branch_id = config.branch_test_id

    for i in range(V.shape[0]):
        # Work on copies: pypower's ext2int1 writes renumbered bus indices back
        # into the arrays it is given, so passing config.bus / config.gen
        # directly would silently mutate the shared configuration.
        bus_test = config.bus.copy()
        gen_test = config.gen.copy()
        branch_test = config.branch.copy()

        # Branch outage: take the branch out of service.
        if branch_id[i] > 0:
            branch_test[branch_id[i] - 1, BR_STATUS] = 0

        # Convert to internal (consecutive) bus numbering.
        _, bus_test, _, branch_test = ext2int1(bus_test, gen_test, branch_test)

        # Rebuild Ybus for this scenario's branch set; Yf/Yt are not needed.
        Ybus, _, _ = makeYbus(config.baseMVA, bus_test, branch_test)

        # Complex power injection S = V * conj(Ybus @ V).
        I = Ybus.dot(V[i]).conj()
        S[i] = np.multiply(V[i], I)

    P = np.real(S)
    Q = np.imag(S)

    # Generation = net injection + local load at each generator bus.
    Pg = P[:, bus_Pg] + Pdtest[:, bus_Pg]
    Qg = Q[:, bus_Qg] + Qdtest[:, bus_Qg]

    # Load is minus the net injection, except at generator buses where the
    # generator's share is removed first (vectorized over all scenarios).
    Pd = -P
    Qd = -Q
    Pd[:, bus_Pg] = Pg - P[:, bus_Pg]
    Qd[:, bus_Qg] = Qg - Q[:, bus_Qg]

    return Pg, Qg, Pd, Qd

# Generation cost
def get_Pgcost(Pg, idxPg, gencost):
    """Total polynomial generation cost of every sample.

    Each row of ``Pg`` is one sample. The outputs are first multiplied by
    ``config.baseMVA``. Then, for every generator selected by ``idxPg``, the
    cost ``a * p**2 + b * p + c`` is evaluated with
    ``a, b, c = gencost[idxPg, 4], gencost[idxPg, 5], gencost[idxPg, 6]``
    (quadratic, linear and constant coefficients). The per-generator costs
    are then summed.

    Returns:
        1-D float array holding one summed cost per row of ``Pg``.
    """
    Pg_scaled = Pg * config.baseMVA
    quad_coef = gencost[idxPg, 4]
    lin_coef = gencost[idxPg, 5]
    const_coef = gencost[idxPg, 6]
    n_samples = Pg.shape[0]

    row_costs = (
        np.sum(
            quad_coef * (Pg_scaled[k, :] * Pg_scaled[k, :])
            + lin_coef * Pg_scaled[k, :]
            + const_coef
        )
        for k in range(n_samples)
    )
    return np.fromiter(row_costs, dtype=float, count=n_samples)


def evs_score_matrix(y_true, y_pred, verbose=True):
    """Compute the mean per-sample explained variance score (EVS) and its distribution.

    EVS is evaluated row by row. Each row is one sample, and the variance is
    taken across that row's features:

        EVS_i = 1 - Var(y_true[i] - y_pred[i]) / Var(y_true[i])

    Rows whose true values have zero variance are assigned an EVS of 0.

    Args:
        y_true (numpy.ndarray): true values, shape (n_samples, n_features).
        y_pred (numpy.ndarray): predicted values, same shape as ``y_true``.
        verbose (bool): if True, print the distribution table and details of
            up to 100 samples with EVS < 0.

    Returns:
        tuple: ``(mean EVS, distribution)``. ``distribution`` maps each bin
        label ("<0", "0-0.5", "0.5-0.9", "0.9-1", ">1") to
        ``{"count": ..., "percentage": ...}``. Bins are half-open ``[lo, hi)``,
        except "0.9-1", which is the closed interval ``[0.9, 1]`` so that
        perfect predictions (EVS == 1) are counted there. For empty input the
        mean is NaN and all percentages are 0.

    Raises:
        ValueError: if the two arrays have different shapes.
    """
    if y_true.shape != y_pred.shape:
        raise ValueError("输入的矩阵形状不一致！")

    # Per-sample variance of the truth and of the residual.
    total_variance = np.var(y_true, axis=1)
    residual_variance = np.var(y_true - y_pred, axis=1)

    with np.errstate(divide='ignore', invalid='ignore'):
        evs_per_sample = 1 - residual_variance / total_variance
    # EVS is undefined for constant rows; report them as 0.
    evs_per_sample[total_variance == 0] = 0

    # np.histogram bins are [lo, hi) except the last one, which is closed.
    # Placing the "0.9-1" upper edge at the next float above 1 keeps EVS == 1
    # (a perfect or constant-offset prediction) out of the ">1" bucket.
    bins = [-np.inf, 0, 0.5, 0.9, np.nextafter(1.0, np.inf), np.inf]
    bin_labels = ["<0", "0-0.5", "0.5-0.9", "0.9-1", ">1"]

    counts, _ = np.histogram(evs_per_sample, bins=bins)
    total_samples = len(evs_per_sample)

    def _percent(n):
        # Guard against division by zero for empty input.
        return n / total_samples * 100 if total_samples else 0.0

    distribution = {
        label: {"count": count, "percentage": _percent(count)}
        for label, count in zip(bin_labels, counts)
    }

    # Samples with EVS < 0, i.e. worse than predicting each row's mean.
    extreme_mask = evs_per_sample < 0
    extreme_indices = np.where(extreme_mask)[0]
    extreme_values = evs_per_sample[extreme_mask]
    extreme_residual_var = residual_variance[extreme_mask]
    extreme_total_var = total_variance[extreme_mask]

    if verbose:
        print("\nEVS分数分布统计:")
        print("=" * 40)
        for label, stats in distribution.items():
            print(f"{label}: {stats['count']}个样本 ({stats['percentage']:.2f}%)")

        print("\n极端数据统计 (EVS < 0):")
        print("=" * 40)
        extreme_pct = _percent(len(extreme_values))
        print(f"共发现 {len(extreme_values)} 个极端样本 ({extreme_pct:.2f}%)")

        # Details of at most the first 100 extreme samples.
        print("\n前100个极端样本详情 [索引: EVS值, 残差方差, 总方差]:")
        print("=" * 80)

        n_extreme = min(100, len(extreme_indices))
        for idx, evs_val, res_var, tot_var in zip(
            extreme_indices[:n_extreme],
            extreme_values[:n_extreme],
            extreme_residual_var[:n_extreme],
            extreme_total_var[:n_extreme],
        ):
            print(f"样本 {idx}: EVS={evs_val:.6f}, res_var={res_var:.6f}, tot_var={tot_var:.6f}")

        if len(extreme_indices) > 100:
            print(f"\n...已省略{len(extreme_indices)-100}个极端样本")

    mean_evs = np.mean(evs_per_sample) if total_samples else np.nan
    return mean_evs, distribution